{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 34838468.0\n",
      "1 33141592.0\n",
      "2 33985564.0\n",
      "3 31217414.0\n",
      "4 24039622.0\n",
      "5 14898764.0\n",
      "6 8143569.5\n",
      "7 4334779.0\n",
      "8 2517978.0\n",
      "9 1649801.75\n",
      "10 1204558.5\n",
      "11 945184.25\n",
      "12 774128.8125\n",
      "13 649778.375\n",
      "14 553588.8125\n",
      "15 476175.03125\n",
      "16 412483.875\n",
      "17 359298.03125\n",
      "18 314452.71875\n",
      "19 276398.4375\n",
      "20 243851.640625\n",
      "21 215874.78125\n",
      "22 191771.328125\n",
      "23 170882.953125\n",
      "24 152684.609375\n",
      "25 136768.46875\n",
      "26 122816.640625\n",
      "27 110555.78125\n",
      "28 99740.34375\n",
      "29 90173.53125\n",
      "30 81684.5859375\n",
      "31 74128.4453125\n",
      "32 67391.7421875\n",
      "33 61374.90625\n",
      "34 55986.19140625\n",
      "35 51149.875\n",
      "36 46806.2265625\n",
      "37 42899.63671875\n",
      "38 39372.30078125\n",
      "39 36180.7578125\n",
      "40 33290.80078125\n",
      "41 30668.43359375\n",
      "42 28286.7421875\n",
      "43 26118.15625\n",
      "44 24142.62890625\n",
      "45 22339.470703125\n",
      "46 20692.41015625\n",
      "47 19185.560546875\n",
      "48 17805.91015625\n",
      "49 16540.005859375\n",
      "50 15378.306640625\n",
      "51 14310.693359375\n",
      "52 13328.041015625\n",
      "53 12423.91015625\n",
      "54 11590.2216796875\n",
      "55 10820.880859375\n",
      "56 10110.25\n",
      "57 9453.42578125\n",
      "58 8845.7392578125\n",
      "59 8282.779296875\n",
      "60 7760.83642578125\n",
      "61 7276.7138671875\n",
      "62 6827.05322265625\n",
      "63 6409.59326171875\n",
      "64 6021.2001953125\n",
      "65 5659.865234375\n",
      "66 5323.236328125\n",
      "67 5009.48974609375\n",
      "68 4716.978515625\n",
      "69 4443.87158203125\n",
      "70 4188.82666015625\n",
      "71 3950.421875\n",
      "72 3727.47021484375\n",
      "73 3518.889892578125\n",
      "74 3323.48828125\n",
      "75 3140.459716796875\n",
      "76 2968.821044921875\n",
      "77 2807.854736328125\n",
      "78 2656.776123046875\n",
      "79 2514.815185546875\n",
      "80 2381.480224609375\n",
      "81 2256.06591796875\n",
      "82 2138.1416015625\n",
      "83 2027.161376953125\n",
      "84 1922.5999755859375\n",
      "85 1824.1453857421875\n",
      "86 1731.306640625\n",
      "87 1643.77001953125\n",
      "88 1561.208251953125\n",
      "89 1483.25048828125\n",
      "90 1409.652099609375\n",
      "91 1340.10986328125\n",
      "92 1274.3817138671875\n",
      "93 1212.27197265625\n",
      "94 1153.5135498046875\n",
      "95 1097.8907470703125\n",
      "96 1045.2249755859375\n",
      "97 995.361083984375\n",
      "98 948.1497802734375\n",
      "99 903.394287109375\n",
      "100 860.93896484375\n",
      "101 820.6868896484375\n",
      "102 782.5091552734375\n",
      "103 746.289306640625\n",
      "104 711.8989868164062\n",
      "105 679.234619140625\n",
      "106 648.2129516601562\n",
      "107 618.7276611328125\n",
      "108 590.7075805664062\n",
      "109 564.0698852539062\n",
      "110 538.7382202148438\n",
      "111 514.6378173828125\n",
      "112 491.7018127441406\n",
      "113 469.87725830078125\n",
      "114 449.1021728515625\n",
      "115 429.31732177734375\n",
      "116 410.4703369140625\n",
      "117 392.5182189941406\n",
      "118 375.4024353027344\n",
      "119 359.0985107421875\n",
      "120 343.551025390625\n",
      "121 328.72125244140625\n",
      "122 314.581298828125\n",
      "123 301.0873107910156\n",
      "124 288.2192077636719\n",
      "125 275.9346618652344\n",
      "126 264.20513916015625\n",
      "127 253.00869750976562\n",
      "128 242.31851196289062\n",
      "129 232.10733032226562\n",
      "130 222.35833740234375\n",
      "131 213.0459442138672\n",
      "132 204.1505584716797\n",
      "133 195.64695739746094\n",
      "134 187.5193634033203\n",
      "135 179.75010681152344\n",
      "136 172.3186798095703\n",
      "137 165.20811462402344\n",
      "138 158.41207885742188\n",
      "139 151.90943908691406\n",
      "140 145.68804931640625\n",
      "141 139.73297119140625\n",
      "142 134.0347900390625\n",
      "143 128.5826416015625\n",
      "144 123.36618041992188\n",
      "145 118.36921691894531\n",
      "146 113.58551788330078\n",
      "147 109.00410461425781\n",
      "148 104.61421966552734\n",
      "149 100.4097671508789\n",
      "150 96.38236236572266\n",
      "151 92.52302551269531\n",
      "152 88.82571411132812\n",
      "153 85.2813491821289\n",
      "154 81.8853988647461\n",
      "155 78.62787628173828\n",
      "156 75.50669860839844\n",
      "157 72.51366424560547\n",
      "158 69.64488983154297\n",
      "159 66.89266967773438\n",
      "160 64.25312042236328\n",
      "161 61.72319793701172\n",
      "162 59.29533004760742\n",
      "163 56.966575622558594\n",
      "164 54.7320671081543\n",
      "165 52.589256286621094\n",
      "166 50.531497955322266\n",
      "167 48.55730438232422\n",
      "168 46.663204193115234\n",
      "169 44.84599685668945\n",
      "170 43.10108947753906\n",
      "171 41.42528533935547\n",
      "172 39.81774139404297\n",
      "173 38.27351760864258\n",
      "174 36.7920036315918\n",
      "175 35.36980056762695\n",
      "176 34.00321578979492\n",
      "177 32.6913948059082\n",
      "178 31.430967330932617\n",
      "179 30.221912384033203\n",
      "180 29.05903434753418\n",
      "181 27.942394256591797\n",
      "182 26.869720458984375\n",
      "183 25.839725494384766\n",
      "184 24.85027503967285\n",
      "185 23.89914894104004\n",
      "186 22.985315322875977\n",
      "187 22.108097076416016\n",
      "188 21.264678955078125\n",
      "189 20.453506469726562\n",
      "190 19.67483139038086\n",
      "191 18.92633819580078\n",
      "192 18.20647430419922\n",
      "193 17.51482391357422\n",
      "194 16.850736618041992\n",
      "195 16.211938858032227\n",
      "196 15.597586631774902\n",
      "197 15.006938934326172\n",
      "198 14.439468383789062\n",
      "199 13.893719673156738\n",
      "200 13.36877727508545\n",
      "201 12.864436149597168\n",
      "202 12.379522323608398\n",
      "203 11.912912368774414\n",
      "204 11.464310646057129\n",
      "205 11.033108711242676\n",
      "206 10.618121147155762\n",
      "207 10.219304084777832\n",
      "208 9.835758209228516\n",
      "209 9.4669189453125\n",
      "210 9.112168312072754\n",
      "211 8.770597457885742\n",
      "212 8.44197940826416\n",
      "213 8.126016616821289\n",
      "214 7.822319030761719\n",
      "215 7.529946327209473\n",
      "216 7.248904705047607\n",
      "217 6.978030681610107\n",
      "218 6.717903137207031\n",
      "219 6.467252254486084\n",
      "220 6.226631164550781\n",
      "221 5.994656562805176\n",
      "222 5.77148962020874\n",
      "223 5.556979656219482\n",
      "224 5.350459575653076\n",
      "225 5.151628494262695\n",
      "226 4.960315704345703\n",
      "227 4.776061534881592\n",
      "228 4.599002838134766\n",
      "229 4.428690433502197\n",
      "230 4.264365196228027\n",
      "231 4.106524467468262\n",
      "232 3.954587459564209\n",
      "233 3.8083302974700928\n",
      "234 3.6674869060516357\n",
      "235 3.531869888305664\n",
      "236 3.4015560150146484\n",
      "237 3.2758138179779053\n",
      "238 3.154892921447754\n",
      "239 3.038429021835327\n",
      "240 2.9265005588531494\n",
      "241 2.818514823913574\n",
      "242 2.7148404121398926\n",
      "243 2.6147611141204834\n",
      "244 2.5184597969055176\n",
      "245 2.425811767578125\n",
      "246 2.33675217628479\n",
      "247 2.2508866786956787\n",
      "248 2.1679904460906982\n",
      "249 2.088416814804077\n",
      "250 2.011852741241455\n",
      "251 1.9380263090133667\n",
      "252 1.8668866157531738\n",
      "253 1.798403024673462\n",
      "254 1.7324905395507812\n",
      "255 1.6690473556518555\n",
      "256 1.6079224348068237\n",
      "257 1.5490334033966064\n",
      "258 1.4923211336135864\n",
      "259 1.4376897811889648\n",
      "260 1.3850973844528198\n",
      "261 1.3344584703445435\n",
      "262 1.2856464385986328\n",
      "263 1.238754153251648\n",
      "264 1.1935430765151978\n",
      "265 1.1500052213668823\n",
      "266 1.1079671382904053\n",
      "267 1.0675686597824097\n",
      "268 1.0286755561828613\n",
      "269 0.9910664558410645\n",
      "270 0.9549412727355957\n",
      "271 0.9201943278312683\n",
      "272 0.886725664138794\n",
      "273 0.8543822169303894\n",
      "274 0.8232855200767517\n",
      "275 0.7933573126792908\n",
      "276 0.7645137906074524\n",
      "277 0.7367044687271118\n",
      "278 0.7099923491477966\n",
      "279 0.6841153502464294\n",
      "280 0.6592459678649902\n",
      "281 0.635317325592041\n",
      "282 0.6122614145278931\n",
      "283 0.5900916457176208\n",
      "284 0.5685697197914124\n",
      "285 0.5479152798652649\n",
      "286 0.5280589461326599\n",
      "287 0.5089750289916992\n",
      "288 0.4905554950237274\n",
      "289 0.472759872674942\n",
      "290 0.4556622803211212\n",
      "291 0.4391653537750244\n",
      "292 0.4232412576675415\n",
      "293 0.4079386293888092\n",
      "294 0.3932224214076996\n",
      "295 0.37897971272468567\n",
      "296 0.3652483820915222\n",
      "297 0.3520159125328064\n",
      "298 0.3393038809299469\n",
      "299 0.3270008862018585\n",
      "300 0.3152548670768738\n",
      "301 0.30385342240333557\n",
      "302 0.2928246557712555\n",
      "303 0.2822943329811096\n",
      "304 0.2720912992954254\n",
      "305 0.2622653841972351\n",
      "306 0.2528100907802582\n",
      "307 0.24370644986629486\n",
      "308 0.2349167913198471\n",
      "309 0.22648417949676514\n",
      "310 0.21827200055122375\n",
      "311 0.2104230374097824\n",
      "312 0.2028466761112213\n",
      "313 0.19551461935043335\n",
      "314 0.18849556148052216\n",
      "315 0.18169602751731873\n",
      "316 0.17520159482955933\n",
      "317 0.16889818012714386\n",
      "318 0.16282296180725098\n",
      "319 0.15692870318889618\n",
      "320 0.15129156410694122\n",
      "321 0.14583589136600494\n",
      "322 0.14062601327896118\n",
      "323 0.13560402393341064\n",
      "324 0.13071848452091217\n",
      "325 0.12600745260715485\n",
      "326 0.12150520086288452\n",
      "327 0.11713986098766327\n",
      "328 0.11292973905801773\n",
      "329 0.10891234129667282\n",
      "330 0.10497201979160309\n",
      "331 0.10120904445648193\n",
      "332 0.09758470952510834\n",
      "333 0.09408371150493622\n",
      "334 0.09073255211114883\n",
      "335 0.08748523890972137\n",
      "336 0.08433917164802551\n",
      "337 0.08130931109189987\n",
      "338 0.07841827720403671\n",
      "339 0.07563333958387375\n",
      "340 0.0729256123304367\n",
      "341 0.07031291723251343\n",
      "342 0.0678146705031395\n",
      "343 0.06538737565279007\n",
      "344 0.063047394156456\n",
      "345 0.0607999786734581\n",
      "346 0.058629851788282394\n",
      "347 0.05653411149978638\n",
      "348 0.054513778537511826\n",
      "349 0.052576784044504166\n",
      "350 0.050706181675195694\n",
      "351 0.04890399053692818\n",
      "352 0.04715315252542496\n",
      "353 0.04547734931111336\n",
      "354 0.04385605826973915\n",
      "355 0.04231569170951843\n",
      "356 0.040807854384183884\n",
      "357 0.03934280201792717\n",
      "358 0.03795529529452324\n",
      "359 0.036607421934604645\n",
      "360 0.03531341627240181\n",
      "361 0.034060973674058914\n",
      "362 0.032853465527296066\n",
      "363 0.03168454393744469\n",
      "364 0.030570708215236664\n",
      "365 0.02948746271431446\n",
      "366 0.028439929708838463\n",
      "367 0.027434049174189568\n",
      "368 0.026473769918084145\n",
      "369 0.025539135560393333\n",
      "370 0.024630391970276833\n",
      "371 0.023760443553328514\n",
      "372 0.022917339578270912\n",
      "373 0.022114001214504242\n",
      "374 0.021344952285289764\n",
      "375 0.020591987296938896\n",
      "376 0.019866283982992172\n",
      "377 0.01916738599538803\n",
      "378 0.018495794385671616\n",
      "379 0.017842840403318405\n",
      "380 0.01721605844795704\n",
      "381 0.016614731401205063\n",
      "382 0.01603921875357628\n",
      "383 0.015477508306503296\n",
      "384 0.014940518885850906\n",
      "385 0.014416388235986233\n",
      "386 0.013913419097661972\n",
      "387 0.013429778628051281\n",
      "388 0.012963760644197464\n",
      "389 0.012515738606452942\n",
      "390 0.012075887992978096\n",
      "391 0.011660278774797916\n",
      "392 0.011267208494246006\n",
      "393 0.010874430648982525\n",
      "394 0.01050148531794548\n",
      "395 0.010149220004677773\n",
      "396 0.009795350953936577\n",
      "397 0.009464341215789318\n",
      "398 0.009134096093475819\n",
      "399 0.00882729422301054\n",
      "400 0.008524206466972828\n",
      "401 0.00823886413127184\n",
      "402 0.007954451255500317\n",
      "403 0.007691939361393452\n",
      "404 0.007424088194966316\n",
      "405 0.007173002231866121\n",
      "406 0.0069321137852966785\n",
      "407 0.006700365804135799\n",
      "408 0.006480945739895105\n",
      "409 0.006260807625949383\n",
      "410 0.006060661282390356\n",
      "411 0.005856854375451803\n",
      "412 0.005659427959471941\n",
      "413 0.0054711103439331055\n",
      "414 0.005297867581248283\n",
      "415 0.005125643219798803\n",
      "416 0.0049557751044631\n",
      "417 0.004796652123332024\n",
      "418 0.004635329358279705\n",
      "419 0.00448473310098052\n",
      "420 0.0043402304872870445\n",
      "421 0.004202606622129679\n",
      "422 0.00406738743185997\n",
      "423 0.003938395995646715\n",
      "424 0.0038121924735605717\n",
      "425 0.0036872588098049164\n",
      "426 0.003572269110009074\n",
      "427 0.0034584186505526304\n",
      "428 0.003352078376337886\n",
      "429 0.003247129265218973\n",
      "430 0.0031457652803510427\n",
      "431 0.0030436681117862463\n",
      "432 0.0029485125560313463\n",
      "433 0.0028590415604412556\n",
      "434 0.0027684206143021584\n",
      "435 0.002683203434571624\n",
      "436 0.0025994658935815096\n",
      "437 0.0025183509569615126\n",
      "438 0.002443921286612749\n",
      "439 0.002367061795666814\n",
      "440 0.002296499442309141\n",
      "441 0.002228411380201578\n",
      "442 0.0021621924825012684\n",
      "443 0.002096005482599139\n",
      "444 0.002032945631071925\n",
      "445 0.001975882798433304\n",
      "446 0.0019150436855852604\n",
      "447 0.0018562431214377284\n",
      "448 0.001803578226827085\n",
      "449 0.001751550124026835\n",
      "450 0.0017019915394484997\n",
      "451 0.0016535762697458267\n",
      "452 0.0016033551655709743\n",
      "453 0.0015580898616462946\n",
      "454 0.0015138316666707397\n",
      "455 0.001470763934776187\n",
      "456 0.0014283419586718082\n",
      "457 0.0013895528391003609\n",
      "458 0.0013509091222658753\n",
      "459 0.0013134789187461138\n",
      "460 0.0012777299853041768\n",
      "461 0.0012439462589100003\n",
      "462 0.0012099341256543994\n",
      "463 0.001178003498353064\n",
      "464 0.0011446793796494603\n",
      "465 0.0011149568017572165\n",
      "466 0.0010865137446671724\n",
      "467 0.0010563313262537122\n",
      "468 0.0010283751180395484\n",
      "469 0.0010014176368713379\n",
      "470 0.0009750793687999249\n",
      "471 0.0009507556678727269\n",
      "472 0.0009257718338631094\n",
      "473 0.000901735620573163\n",
      "474 0.0008794458699412644\n",
      "475 0.0008576472755521536\n",
      "476 0.0008356127655133605\n",
      "477 0.000814672268461436\n",
      "478 0.0007964319665916264\n",
      "479 0.0007748631178401411\n",
      "480 0.0007570949383080006\n",
      "481 0.0007368183578364551\n",
      "482 0.0007191302138380706\n",
      "483 0.0007003776845522225\n",
      "484 0.0006857088301330805\n",
      "485 0.0006688642897643149\n",
      "486 0.0006528224912472069\n",
      "487 0.0006381114944815636\n",
      "488 0.0006220077048055828\n",
      "489 0.0006093481206335127\n",
      "490 0.0005934870569035411\n",
      "491 0.0005786480032838881\n",
      "492 0.0005662408075295389\n",
      "493 0.0005527993780560791\n",
      "494 0.0005422201938927174\n",
      "495 0.0005303858197294176\n",
      "496 0.0005179387517273426\n",
      "497 0.000506275799125433\n",
      "498 0.00049491913523525\n",
      "499 0.00048322349903173745\n"
     ]
    }
   ],
   "source": [
    "device = torch.device('cpu')\n",
    "N, D_in, H, D_out = 64, 1000, 100, 10\n",
    "\n",
    "x = torch.randn(N, D_in)\n",
    "y = torch.randn(N, D_out)\n",
    "\n",
    "w1 = torch.randn(D_in, H)\n",
    "w2 = torch.randn(H, D_out)\n",
    "\n",
    "learning_rate = 1e-6\n",
    "\n",
    "for t in range(500):\n",
    "    h = torch.mm(x, w1)\n",
    "    h_relu = torch.clamp(h, min=0)\n",
    "    y_pred = torch.mm(h_relu, w2)\n",
    "    \n",
    "    loss = (y_pred - y).pow(2).sum()\n",
    "    print(t, loss.item())\n",
    "    \n",
    "    grad_y_pred = 2.0 * (y_pred - y)\n",
    "    grad_w2 = torch.mm(h_relu.t(), grad_y_pred)\n",
    "    grad_h_relu = torch.mm(grad_y_pred, w2.t())\n",
    "    grad_h = grad_h_relu.clone()\n",
    "    grad_h[h < 0] = 0\n",
    "    grad_w1 = torch.mm(x.t(), grad_h)\n",
    "    \n",
    "    w1 -= learning_rate * grad_w1\n",
    "    w2 -= learning_rate * grad_w2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
